// Copyright (C) Microsoft Corporation. All rights reserved.

#pragma once
#include "WslCoreNetworkEndpointSettings.h"
#include "WslCoreHostDnsInfo.h"

namespace wsl::core::networking {
// Sync state recorded for a tracked item. In this file it is stored per IP address
// (TrackedIpAddress::SyncStatus), per route (TrackedRoute::SyncStatus) and for the DNS
// server list (IpStateTracking::DnsServersSyncStatus).
// NOTE(review): the per-value meanings below are inferred from the enumerator names only -
// confirm against the code that drives the sync (SyncIpStateWithLinux).
enum class TrackedIpStateSyncStatus
{
    PendingAdd,     // initial value of every tracked item in this file; presumably "still has to be added"
    PendingUpdate,  // presumably "already present, but a change still has to be pushed"
    PendingRemoval, // presumably "still has to be removed"
    Synced          // presumably "nothing left to push"
};
// Returns the name of the given enumerator as a narrow string literal.
// A value that matches none of the declared enumerators yields "Unknown".
constexpr auto ToString(networking::TrackedIpStateSyncStatus status) noexcept
{
    using Status = networking::TrackedIpStateSyncStatus;

    switch (status)
    {
    case Status::PendingAdd:
        return "PendingAdd";

    case Status::PendingUpdate:
        return "PendingUpdate";

    case Status::PendingRemoval:
        return "PendingRemoval";

    case Status::Synced:
        return "Synced";
    }

    // reached only when 'status' holds a value outside the declared enumerators
    return "Unknown";
}

struct TrackedIpAddress
{
    EndpointIpAddress Address{};

    // The following fields need to be changed from a std::set iterator (which is always const)
    // in SyncIpStateWithLinux - that's why they're marked mutable.
    mutable TrackedIpStateSyncStatus SyncStatus = TrackedIpStateSyncStatus::PendingAdd;
    mutable uint32_t SyncRetryCount = MaxSyncRetryCount;
    mutable uint32_t LoopbackSyncRetryCount = MaxLoopbackSyncRetryCount;

    static constexpr uint32_t MaxSyncRetryCount = 15;
    static constexpr uint32_t MaxLoopbackSyncRetryCount = 5;

    TrackedIpAddress() = default;
    ~TrackedIpAddress() noexcept = default;

    // not copyable to avoid subtle bugs where 2 objects are trying to track state of the same address
    TrackedIpAddress(const TrackedIpAddress&) = delete;
    TrackedIpAddress& operator=(const TrackedIpAddress&) = delete;
    TrackedIpAddress(TrackedIpAddress&&) = default;
    TrackedIpAddress& operator=(TrackedIpAddress&&) = default;

    explicit TrackedIpAddress(const EndpointIpAddress& other)
    {
        Address = other;
    }

    TrackedIpAddress& operator=(const EndpointIpAddress& other)
    {
        Address = other;
        SyncStatus = TrackedIpStateSyncStatus::PendingAdd;
        SyncRetryCount = MaxSyncRetryCount;
        LoopbackSyncRetryCount = MaxLoopbackSyncRetryCount;
        return *this;
    }

    wsl::shared::hns::IPAddress ConvertToHnsSettingsMsg() const
    {
        wsl::shared::hns::IPAddress addr{};
        addr.Family = Address.Address.si_family;
        addr.Address = Address.AddressString;
        addr.OnLinkPrefixLength = Address.PrefixLength;
        addr.PreferredLifetime = Address.PreferredLifetime;
        addr.PrefixOrigin = Address.PrefixOrigin;
        addr.SuffixOrigin = Address.SuffixOrigin;
        return addr;
    }

    bool operator==(const TrackedIpAddress& other) const noexcept
    {
        return Address == other.Address;
    }

    bool operator!=(const TrackedIpAddress& other) const noexcept
    {
        return !(*this == other);
    }

    bool operator<(const TrackedIpAddress& other) const noexcept
    {
        return Address < other.Address;
    }
};

struct TrackedRoute
{
    EndpointRoute Route{};

    // The following fields need to be changed from a std::set iterator (which is always const)
    // in SyncIpStateWithLinux - that's why they're marked mutable.
    mutable TrackedIpStateSyncStatus SyncStatus = TrackedIpStateSyncStatus::PendingAdd;
    mutable uint32_t SyncRetryCount = MaxSyncRetryCount;
    mutable bool LinuxConflictRemoved = false; // only used for prefix routes

    static constexpr uint32_t MaxSyncRetryCount = 15;

    TrackedRoute() = default;
    ~TrackedRoute() noexcept = default;

    // not copyable to avoid subtle bugs where 2 objects are trying to track state of the same route
    TrackedRoute(const TrackedRoute&) = delete;
    TrackedRoute& operator=(const TrackedRoute&) = delete;
    TrackedRoute(TrackedRoute&&) = default;
    TrackedRoute& operator=(TrackedRoute&&) = default;

    explicit TrackedRoute(const EndpointRoute& other)
    {
        Route = other;
    }

    TrackedRoute& operator=(const EndpointRoute& other)
    {
        Route = other;
        SyncStatus = TrackedIpStateSyncStatus::PendingAdd;
        SyncRetryCount = MaxSyncRetryCount;
        LinuxConflictRemoved = false;
        return *this;
    }

    unsigned int LinuxAutoGenRouteMetric() const
    {
        return (Route.Family == AF_INET6) ? 1024 : 0;
    }

    bool CanConflictWithLinuxAutoGenRoute() const
    {
        return Route.IsAutoGeneratedPrefixRoute && (Route.Metric != LinuxAutoGenRouteMetric());
    }

    wsl::shared::hns::Route ConvertToHnsSettingsMsg() const
    {
        wsl::shared::hns::Route route{};
        route.Family = Route.Family;
        route.DestinationPrefix = Route.GetFullDestinationPrefix();
        route.SitePrefixLength = Route.SitePrefixLength;
        route.NextHop = Route.NextHopString;
        route.Metric = Route.Metric;
        return route;
    }

    bool operator==(const TrackedRoute& other) const noexcept
    {
        return Route == other.Route;
    }

    bool operator!=(const TrackedRoute& other) const noexcept
    {
        return !(*this == other);
    }

    bool operator<(const TrackedRoute& other) const noexcept
    {
        // return true if 'this' is less than (i.e. is ordered before) the input argument
        if (Route.IsAutoGeneratedPrefixRoute || other.Route.IsAutoGeneratedPrefixRoute)
        {
            if (Route.IsAutoGeneratedPrefixRoute && other.Route.IsAutoGeneratedPrefixRoute)
            {
                // if both are effectively equivalent, sort by their addresses
                if (Route.DestinationPrefixString == other.Route.DestinationPrefixString)
                {
                    return Route.Metric < other.Route.Metric;
                }
                return Route.DestinationPrefixString < other.Route.DestinationPrefixString;
            }
            // else return true if it's the left that's IsAutoGeneratedPrefixRoute
            return Route.IsAutoGeneratedPrefixRoute;
        }

        if (Route.IsNextHopOnlink() || other.Route.IsNextHopOnlink())
        {
            if (Route.IsNextHopOnlink() && other.Route.IsNextHopOnlink())
            {
                // if both are effectively equivalent, sort by their addresses
                if (Route.DestinationPrefixString == other.Route.DestinationPrefixString)
                {
                    return Route.Metric < other.Route.Metric;
                }
                return Route.DestinationPrefixString < other.Route.DestinationPrefixString;
            }
            // else return true if it's the left that's IsNextHopOnlink()
            return Route.IsNextHopOnlink();
        }

        // else it's an Add or Update for a route that's not an auto-generated route
        // and whose next-hop address is not on-link
        if (Route.DestinationPrefixString == other.Route.DestinationPrefixString)
        {
            return Route.Metric < other.Route.Metric;
        }
        return Route.DestinationPrefixString < other.Route.DestinationPrefixString;
    }
};

struct IpStateTracking
{
    bool InitialSyncComplete = false;

    GUID InterfaceGuid{};
    std::set<TrackedIpAddress> IpAddresses{};
    std::set<TrackedRoute> Routes{};
    std::vector<std::wstring> DnsServers{};
    // currently DnsServersSyncStatus is only tracked for tracing purposes - it's not being set through Linux
    TrackedIpStateSyncStatus DnsServersSyncStatus = TrackedIpStateSyncStatus::PendingAdd;
    DnsInfo DnsInfo{};
    ULONG InterfaceMtu = 0;
    bool IsMetered = false;

    IpStateTracking() = default;
    IpStateTracking(const std::optional<GUID>& vmCreatorId) : FirewallVmCreatorId(vmCreatorId)
    {
    }

    ~IpStateTracking() noexcept
    {
        if (FirewallVmCreatorId)
        {
            ResetState();
        }
    }

    // cannot copy - the d'tor clears FW rules so move-operators must track when moved-from
    // it does this by clearing the std::optional FirewallVmCreatorId when moved-from
    IpStateTracking(const IpStateTracking&) = delete;
    IpStateTracking& operator=(const IpStateTracking&) = delete;

    IpStateTracking(IpStateTracking&& lhs) noexcept :
        InitialSyncComplete(std::move(lhs.InitialSyncComplete)),
        InterfaceGuid(std::move(lhs.InterfaceGuid)),
        IpAddresses(std::move(lhs.IpAddresses)),
        Routes(std::move(lhs.Routes)),
        InterfaceMtu(std::move(lhs.InterfaceMtu)),
        IsMetered(std::move(lhs.IsMetered)),
        FirewallVmCreatorId(std::move(lhs.FirewallVmCreatorId)),
        FirewallTrackedIpAddresses(std::move(lhs.FirewallTrackedIpAddresses))
    {
        // must reset FirewallVmCreatorId so the d'tor won't reset FW state from the moved-from object
        lhs.FirewallVmCreatorId.reset();
    }

    IpStateTracking& operator=(IpStateTracking&& lhs) noexcept
    {
        InterfaceGuid = std::move(lhs.InterfaceGuid);
        IpAddresses = std::move(lhs.IpAddresses);
        Routes = std::move(lhs.Routes);
        InterfaceMtu = std::move(lhs.InterfaceMtu);
        IsMetered = std::move(lhs.IsMetered);
        InitialSyncComplete = std::move(lhs.InitialSyncComplete);
        FirewallVmCreatorId = std::move(lhs.FirewallVmCreatorId);
        FirewallTrackedIpAddresses = std::move(lhs.FirewallTrackedIpAddresses);

        // must reset FirewallVmCreatorId so the d'tor won't reset FW state from the moved-from object
        lhs.FirewallVmCreatorId.reset();

        return *this;
    }

    void ResetState() noexcept
    {
        WSL_LOG("IpStateTracking::ResetState");
        InitialSyncComplete = false;
        InterfaceGuid = {};
        IpAddresses.clear();
        Routes.clear();
        DnsServers.clear();
        DnsServersSyncStatus = TrackedIpStateSyncStatus::PendingAdd;
        InterfaceMtu = 0;
        IsMetered = false;
        SyncFirewallState({});
    }

    void SeedInitialState(const NetworkSettings& settings) noexcept
    {
        InterfaceGuid = settings.InterfaceGuid;
        InterfaceMtu = settings.GetEffectiveMtu();
        IsMetered = settings.IsMetered;

        WSL_LOG(
            "IpStateTracking::SeedInitialState",
            TraceLoggingValue(InterfaceGuid, "InterfaceGuid"),
            TraceLoggingValue(InterfaceMtu, "InterfaceMtu"),
            TraceLoggingValue(IsMetered, "IsMetered"),
            TraceLoggingValue(IpAddresses.size(), "IpAddresses.size()"),
            TraceLoggingValue(Routes.size(), "Routes.size()"),
            TraceLoggingValue(FirewallVmCreatorId.value_or(GUID{}), "FirewallVmCreatorId"));

        // not updating any other fields in this SeedInitialState
        // Address/Route/DnsServer objects are cleared when disconnected
        //   and updated as we confirm they are pushed to the container
    }

    void SyncFirewallState(const NetworkSettings& preferredNetwork) noexcept;

private:
    std::optional<GUID> FirewallVmCreatorId{};
    std::set<EndpointIpAddress> FirewallTrackedIpAddresses{};
};
} // namespace wsl::core::networking
